import json
import logging
import math
import os

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, ShuffleSplit
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import index2mask, set_dataset_attr


def prepare_splits(dataset):
    """Ready train/val/test splits.

    Determine the type of split from the config and call the corresponding
    split generation / verification function.
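
    Example (hypothetical YAML config snippet; this function dispatches on
    ``cfg.dataset.split_mode``)::

        dataset:
          split_mode: cv-stratifiedkfold-5  # or 'standard', 'random'
          split_index: 0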
    """
    # Read the dataset split mode from the config; it determines how the
    # dataset is partitioned, e.g. 'standard', 'random', or 'cv-<type>-<k>'.
    split_mode = cfg.dataset.split_mode

    # 'standard': use the predefined train/val/test split that ships with
    # the dataset.
    if split_mode == 'standard':
        setup_standard_split(dataset)
    # 'random': randomly split the dataset into train/val/test according to
    # the ratios defined in the config.
    elif split_mode == 'random':
        setup_random_split(dataset)
    # 'cv-<type>-<k>': cross-validation; parse the splitter type and the
    # number of folds from the mode string.
    elif split_mode.startswith('cv-'):
        cv_type, k = split_mode.split('-')[1:]
        setup_cv_split(dataset, cv_type, int(k))
    else:
        raise ValueError(f"Unknown split mode: {split_mode}")


def setup_standard_split(dataset):
    """Select a standard split.

    Use standard splits that come with the dataset. Pick one split based on the
    ``split_index`` from the config file if multiple splits are available.

    GNNBenchmarkDatasets have splits that are not prespecified as masks. Therefore,
    they are handled differently and are first processed to generate the masks.

    Raises:
        ValueError: If any one of train/val/test mask is missing.
        IndexError: If the ``split_index`` is greater or equal to the total
            number of splits available.
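
    Example (illustrative shapes, assuming a dataset that ships with 10
    pre-defined splits)::

        mask = dataset.data.train_mask      # shape [num_nodes, 10]
        mask[:, cfg.dataset.split_index]    # shape [num_nodes]: the chosen split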
    """
    # Read the split index (which split to pick when several are available)
    # and the task level ('node', 'graph', or 'link_pred') from the config.
    split_index = cfg.dataset.split_index
    task_level = cfg.dataset.task

    # Node-level tasks require train/val/test node masks on the dataset.
    if task_level == 'node':
        for split_name in 'train_mask', 'val_mask', 'test_mask':
            # Check if the train/val/test split mask is available
            mask = getattr(dataset.data, split_name, None)
            if mask is None:
                raise ValueError(f"Missing '{split_name}' for standard split")

            # If the mask is 2-D, multiple splits are available; pick the
            # column given by split_index.
            if mask.dim() == 2:
                if split_index >= mask.shape[1]:
                    raise IndexError(f"Specified split index ({split_index}) is "
                                     f"out of range of the number of available "
                                     f"splits ({mask.shape[1]}) for {split_name}")
                set_dataset_attr(dataset, split_name, mask[:, split_index],
                                 len(mask[:, split_index]))
            # A 1-D mask means the dataset has only a single standard split,
            # so any nonzero split_index is out of range.
            else:
                if split_index != 0:
                    raise IndexError("This dataset has a single standard split")

    elif task_level == 'graph':
        for split_name in 'train_graph_index', 'val_graph_index', 'test_graph_index':
            if not hasattr(dataset.data, split_name):
                raise ValueError(f"Missing '{split_name}' for standard split")
        if split_index != 0:
            raise NotImplementedError(f"Multiple standard splits not supported "
                                      f"for dataset task level: {task_level}")

    elif task_level == 'link_pred':
        for split_name in 'train_edge_index', 'val_edge_index', 'test_edge_index':
            if not hasattr(dataset.data, split_name):
                raise ValueError(f"Missing '{split_name}' for standard split")
        if split_index != 0:
            raise NotImplementedError(f"Multiple standard splits not supported "
                                      f"for dataset task level: {task_level}")

    else:
        if split_index != 0:
            raise NotImplementedError(f"Multiple standard splits not supported "
                                      f"for dataset task level: {task_level}")


def setup_random_split(dataset):
    """Generate random splits.

    Generate random train/val/test splits based on the ratios defined in the
    config file.

    Raises:
        ValueError: If the number of split ratios is not equal to 3, or the
            ratios do not sum up to 1 (within floating-point tolerance).
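
    Example (illustrative arithmetic): with ratios ``[0.6, 0.2, 0.2]`` the
    first ``ShuffleSplit`` holds out 60% for training, and the second takes
    ``0.2 / (1 - 0.6) = 0.5`` of the remaining 40% for validation, leaving
    the last 20% for testing::

        >>> split_ratios = [0.6, 0.2, 0.2]
        >>> split_ratios[1] / (1 - split_ratios[0])
        0.5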
    """
    split_ratios = cfg.dataset.split

    if len(split_ratios) != 3:
        raise ValueError(
            f"Three split ratios are expected for train/val/test, received "
            f"{len(split_ratios)} split ratios: {repr(split_ratios)}")
    elif not math.isclose(sum(split_ratios), 1.0):
        raise ValueError(
            f"The train/val/test split ratios must sum up to 1, input ratios "
            f"sum up to {sum(split_ratios):.2f} instead: {repr(split_ratios)}")

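    # ShuffleSplit.split(X, y) only uses the number of samples to draw the
    # index arrays; the labels are passed for both arguments as placeholders.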
    train_index, val_test_index = next(
        ShuffleSplit(
            train_size=split_ratios[0],
            random_state=cfg.seed
        ).split(dataset.data.y, dataset.data.y)
    )
    val_index, test_index = next(
        ShuffleSplit(
            train_size=split_ratios[1] / (1 - split_ratios[0]),
            random_state=cfg.seed
        ).split(dataset.data.y[val_test_index], dataset.data.y[val_test_index])
    )
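    # Map the second split's local indices back to global dataset indices.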
    val_index = val_test_index[val_index]
    test_index = val_test_index[test_index]

    set_dataset_splits(dataset, [train_index, val_index, test_index])


def set_dataset_splits(dataset, splits):
    """Set given splits to the dataset object.

    Args:
        dataset: PyG dataset object
        splits: List of train/val/test split indices

    Raises:
        ValueError: If any pair of splits has intersecting indices
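
    Example (sketch of the node-level conversion; ``index2mask`` is imported
    from ``torch_geometric.graphgym.loader``)::

        >>> import torch
        >>> from torch_geometric.graphgym.loader import index2mask
        >>> index2mask(torch.tensor([0, 2]), size=4).tolist()
        [True, False, True, False]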
    """
    # First check whether splits intersect and raise error if so.
    for i in range(len(splits) - 1):
        for j in range(i + 1, len(splits)):
            n_intersect = len(set(splits[i]) & set(splits[j]))
            if n_intersect != 0:
                raise ValueError(
                    f"Splits must not have intersecting indices: "
                    f"split #{i} (n = {len(splits[i])}) and "
                    f"split #{j} (n = {len(splits[j])}) have "
                    f"{n_intersect} intersecting indices"
                )
    task_level = cfg.dataset.task
    if task_level == 'node':
        split_names = ['train_mask', 'val_mask', 'test_mask']
        for split_name, split_index in zip(split_names, splits):
            mask = index2mask(split_index, size=dataset.data.y.shape[0])
            set_dataset_attr(dataset, split_name, mask, len(mask))

    elif task_level == 'graph':
        split_names = [
            'train_graph_index', 'val_graph_index', 'test_graph_index'
        ]
        for split_name, split_index in zip(split_names, splits):
            set_dataset_attr(dataset, split_name, split_index, len(split_index))

    else:
        raise ValueError(f"Unsupported dataset task level: {task_level}")


def setup_cv_split(dataset, cv_type, k):
    """Generate cross-validation splits.

    Generate `k` folds for cross-validation based on the `cv_type` procedure.
    Save these to disk or load existing splits, then select a particular
    train/val/test split based on `cfg.dataset.split_index` from the config
    object.

    Args:
        dataset: PyG dataset object
        cv_type: Identifier for which sklearn fold splitter to use
        k: how many cross-validation folds to split the dataset into

    Raises:
        IndexError: If the `split_index` is greater than or equal to `k`
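
    Example (illustrative fold assignment): with ``k=5`` and
    ``split_index=0``, fold 0 becomes the test set, fold 1 the validation
    set, and the remaining folds form the training set::

        >>> k, split_index = 5, 0
        >>> test_fold, val_fold = split_index, (split_index + 1) % k
        >>> [i for i in range(k) if i not in (test_fold, val_fold)]
        [2, 3, 4]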
    """
    split_index = cfg.dataset.split_index
    split_dir = cfg.dataset.split_dir

    if split_index >= k:
        raise IndexError(f"Specified split_index={split_index} is "
                         f"out of range of the number of folds k={k}")

    os.makedirs(split_dir, exist_ok=True)
    save_file = os.path.join(
        split_dir,
        f"{cfg.dataset.format}_{dataset.name}_{cv_type}-{k}.json"
    )
    if not os.path.isfile(save_file):
        create_cv_splits(dataset, cv_type, k, save_file)
    with open(save_file) as f:
        cv = json.load(f)
    assert cv['dataset'] == dataset.name, "Unexpected dataset CV splits"
    assert cv['n_samples'] == len(dataset), "Dataset length does not match"
    assert cv['n_splits'] > split_index, "Fold selection out of range"
    assert k == cv['n_splits'], f"Expected k={k}, but {cv['n_splits']} found"

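    # Fold `split_index` serves as the test set, the cyclically next fold as
    # the validation set, and the remaining k - 2 folds as the training set.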
    test_ids = cv[str(split_index)]
    val_ids = cv[str((split_index + 1) % k)]
    train_ids = []
    for i in range(k):
        if i != split_index and i != (split_index + 1) % k:
            train_ids.extend(cv[str(i)])

    set_dataset_splits(dataset, [train_ids, val_ids, test_ids])


def create_cv_splits(dataset, cv_type, k, file_name):
    """Create cross-validation splits and save them to file.
    """
    n_samples = len(dataset)
    if cv_type == 'stratifiedkfold':
        kf = StratifiedKFold(n_splits=k, shuffle=True, random_state=123)
        kf_split = kf.split(np.zeros(n_samples), dataset.data.y)
    elif cv_type == 'kfold':
        kf = KFold(n_splits=k, shuffle=True, random_state=123)
        kf_split = kf.split(np.zeros(n_samples))
    else:
        ValueError(f"Unexpected cross-validation type: {cv_type}")

    splits = {'n_samples': n_samples,
              'n_splits': k,
              'cross_validator': str(kf),
              'dataset': dataset.name
              }
    for i, (_, ids) in enumerate(kf_split):
        splits[i] = ids.tolist()
    with open(file_name, 'w') as f:
        json.dump(splits, f)
    logging.info(f"[*] Saved newly generated CV splits by {kf} to {file_name}")
